import json
import os
import random
import time

from absl import app
from absl import flags
import numpy as np
import tensorflow as tf

from algorithms import adam
from algorithms import em
from algorithms import greedy
from algorithms import col_norm_sample
from algorithms import svd
from algorithms import svd_w

# Algorithm names accepted by --algo (see the enum flag below).
# 'omp' is left out since it gives basically the same loss as 'greedy' with a
# much worse run time; run() still has a branch for it when called directly.
# NOTE(review): this comment used to say 'sample' is excluded due to poor
# performance, yet 'sample' is listed below -- confirm which one is intended.
ALGOS = ['adam', 'em', 'greedy', 'sample', 'svd', 'svd_w', 'svd_w+em']
flags.DEFINE_integer('seed', 2023, 'Random seed')
# Upper bound of the rank sweep: run() evaluates every rank from 1 to this value.
flags.DEFINE_integer('rank', 10, 'Rank')
flags.DEFINE_integer('fisher_rank', 1, 'Rank for approximating Fisher matrix')
# Number of repetitions per rank; losses and times are averaged over them.
flags.DEFINE_integer('trials', 20, 'Number of trials')
flags.DEFINE_string('fisher_file', 'fisher.npy', 'File with Fisher matrix')
flags.DEFINE_string('weight_file', 'weight.npy', 'File with weight matrix')
flags.DEFINE_string('output_file', 'results.json', 'File to write results')
flags.DEFINE_enum('algo', 'svd_w', ALGOS, 'Algorithm to use')
FLAGS = flags.FLAGS
# Progress messages go through TensorFlow's logger.
logger = tf.get_logger()


def weighted_loss(y_true, y_pred, weight):
  """Return the weighted squared error between `y_true` and `y_pred`.

  The residual is scaled elementwise by the square root of `weight`, and the
  squared norm of the scaled residual (sum of squares of all its entries) is
  returned, i.e. sum(weight * (y_true - y_pred)**2) up to rounding.

  Args:
    y_true: Reference array.
    y_pred: Approximation, broadcastable against `y_true`.
    weight: Elementwise weights, broadcastable against the residual. Negative
      entries produce NaN because a square root is taken.

  Returns:
    The scalar weighted squared error.
  """
  residual = y_true - y_pred
  scaled_residual = np.multiply(np.sqrt(weight), residual)
  residual_norm = np.linalg.norm(scaled_residual)
  return residual_norm**2


def _approximate(algo, weight, fisher, rank):
  """Return the approximation of `weight` that algorithm `algo` builds."""
  if algo == 'adam':
    left_factor, right_factor = adam.weighted_lra(weight, fisher, rank)
  elif algo == 'em':
    left_factor, right_factor = em.weighted_lra(weight, fisher, rank)
  elif algo == 'greedy':
    left_factor, right_factor = greedy.weighted_lra(weight, fisher, rank, omp=False)
  elif algo == 'omp':
    left_factor, right_factor = greedy.weighted_lra(weight, fisher, rank, omp=True)
  elif algo == 'sample':
    left_factor, right_factor = col_norm_sample.weighted_lra(weight, fisher, rank)
  elif algo == 'svd':
    left_factor, right_factor = svd.svd(weight, rank)
  elif algo == 'svd_w':
    inv_fisher_u, inv_fisher_v, weight_u, weight_v = svd_w.weighted_lra(
      weight, fisher, rank, FLAGS.fisher_rank
    )
    # svd_w has no single (left, right) factor pair: its approximation is the
    # elementwise product of two low-rank matrices.
    return (inv_fisher_u @ inv_fisher_v) * (weight_u @ weight_v)
  elif algo == 'svd_w+em':
    # The warm start always uses fisher_rank=1 and ignores --fisher_rank.
    inv_fisher_u, inv_fisher_v, weight_u, weight_v = svd_w.weighted_lra(
      weight, fisher, rank, fisher_rank=1
    )
    low_rank_inv_fisher = inv_fisher_u @ inv_fisher_v
    low_rank = weight_u @ weight_v
    # seed em with svd_w solution
    left_factor, right_factor = em.weighted_lra(
      weight, fisher, rank, initial_solution=low_rank_inv_fisher * low_rank
    )
  else:
    raise ValueError(f'Invalid algorithm {algo}')
  return left_factor @ right_factor


def run(algo):
  """Benchmark one algorithm over ranks 1..FLAGS.rank.

  Loads the Fisher and weight matrices named by --fisher_file / --weight_file,
  then, for each rank, runs `algo` FLAGS.trials times and records the weighted
  loss and the wall-clock time of each trial. Only the construction of the
  approximation is timed; the loss and the norm diagnostic are computed after
  the clock stops.

  Args:
    algo: Algorithm name. Any entry of ALGOS, or 'omp'.

  Returns:
    A pair (loss_results, time_results): per-rank mean loss and per-rank mean
    time in seconds, each ordered by increasing rank.

  Raises:
    ValueError: If `algo` is not a known algorithm name.
  """
  logger.info('Loading matrices...')
  fisher = np.load(FLAGS.fisher_file)
  weight = np.load(FLAGS.weight_file)
  logger.info('Loaded Fisher matrix of size %d x %d', *fisher.shape)
  logger.info('Loaded weight matrix of size %d x %d', *weight.shape)

  # The pseudo-inverse depends only on `weight`, so compute it once instead of
  # once per trial (and keep it out of the timed region).
  weight_pinv = np.linalg.pinv(weight)

  loss_results = []
  time_results = []
  U_norms = []
  for rank in range(1, FLAGS.rank + 1):
    losses = []
    times = []
    logger.info('Running rank %d...', rank)
    for _ in range(FLAGS.trials):
      start_time = time.perf_counter()
      approximation = _approximate(algo, weight, fisher, rank)
      end_time = time.perf_counter()

      losses.append(weighted_loss(weight, approximation, fisher))
      times.append(end_time - start_time)

      # Diagnostic: squared norm of U = approximation @ pinv(weight), per rank.
      U = approximation @ weight_pinv
      U_norms.append(np.linalg.norm(U)**2 / rank)
    # The value logged as "std" is np.std divided by sqrt(trials).
    logger.info('Loss: avg %f, std %f', np.mean(losses), np.std(losses) / np.sqrt(FLAGS.trials))
    logger.info('Time: avg %f, std %f', np.mean(times), np.std(times) / np.sqrt(FLAGS.trials))
    loss_results.append(np.mean(losses))
    time_results.append(np.mean(times))
  print('Loss', ','.join([str(x) for x in loss_results]))
  print('Time', ','.join([str(x) for x in time_results]))
  # `default` keeps an empty sweep (rank < 1 or trials == 0) from raising.
  print('Max condition', max(U_norms, default=float('nan')))
  return loss_results, time_results


def main(argv) -> None:
  """Run the selected algorithm(s) and write their results as JSON.

  Seeds the Python and NumPy global random generators from --seed, benchmarks
  each selected algorithm with run(), and dumps a mapping
  {algo: {'loss': [...], 'time': [...]}} to --output_file.

  Args:
    argv: Positional command-line arguments from absl; unused.
  """
  del argv
  # Apply --seed so repeated invocations draw the same random numbers.
  # NOTE(review): this covers only the global `random` and `np.random` state;
  # algorithms using TensorFlow's RNG or their own generator need seeding too.
  random.seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)

  results = {}
  # --algo has a default, so all of ALGOS runs only if the flag value is None.
  algos = ALGOS if FLAGS.algo is None else [FLAGS.algo]
  for algo in algos:
    loss_results, time_results = run(algo)
    # Convert NumPy scalars to built-in floats: json cannot serialize values
    # such as np.float32, which np.mean returns for float32 inputs.
    results[algo] = {
      'loss': [float(x) for x in loss_results],
      'time': [float(x) for x in time_results],
    }
  with open(FLAGS.output_file, 'w') as fp:
    json.dump(results, fp)


if __name__ == '__main__':
  app.run(main)